4 Main Results
Interpret and summarize the prediction and stability results.
Evaluate pipeline on test data.
Summarize test set prediction and/or interpretability results.
Insert narrative here.
TODO: Ana - add template tables with interpretation
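As a starting point for such a table, the sketch below tabulates held-out test performance per class. It is only an illustration in base R: the labels and predictions are made-up toy values, and the object names (`test_pred`, `results`) are assumptions rather than part of this template.

```r
# Made-up test labels and predictions, for illustration only.
ytest <- factor(c("a", "a", "a", "b", "b", "b", "c", "c", "c", "c"))
test_pred <- factor(c("a", "a", "b", "b", "b", "c", "c", "c", "c", "a"),
                    levels = levels(ytest))

# Confusion matrix: rows are observed classes, columns are predicted classes.
conf_mat <- table(observed = ytest, predicted = test_pred)

# Per-class recall and precision, plus overall accuracy.
recall <- diag(conf_mat) / rowSums(conf_mat)
precision <- diag(conf_mat) / colSums(conf_mat)
test_accuracy <- sum(diag(conf_mat)) / sum(conf_mat)

results <- data.frame(class = levels(ytest),
                      n_test = as.vector(rowSums(conf_mat)),
                      precision = as.vector(precision),
                      recall = as.vector(recall))
print(conf_mat)
print(results)
print(test_accuracy)
```

Reporting per-class metrics next to the overall accuracy makes it easier to see whether a single class drives the test error.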
What is the real-world question? This could be hypothesis-driven or discovery-based.
Why is this question interesting and important? What are the implications of better understanding this data?
Is there any prior work or background information related to this question? Describe the premise and put your questions and analysis in the domain context.
Briefly describe how this question can be answered in the context of a model or analysis.
Outline the rest of the report/analysis.
Insert narrative here.
What is the data under investigation? Provide a brief overview/description of the data.
How is this data relevant to the problem of interest? In other words, make the link between the data and the domain problem.
Insert narrative here.
How was the data collected or generated (including details on the experimental design)? Be as transparent as possible so that conclusions made from this data are not misinterpreted down the road.
Describe what the data represents in reality, i.e., make the link between the data and reality. Also be sure to describe ways in which the data cannot model reality.
Where is the data stored, and how can it be accessed by others (if applicable)?
Insert narrative here.
Split the data into a training, validation, and test set.
Decide on the proportion of data in each split.
Decide on the “how” to split the data (e.g., random sampling, stratified sampling, etc.), and explain why this is a reasonable way to split the data.
Provide summary statistics and/or figures of the three data sets to illustrate how similar (or different) they are.
Insert narrative here.
TODO: Ana - Provide some code to do the data splitting.
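One possible way to obtain all three sets is sketched below in base R. It assumes a 60/20/20 stratified random split; the proportions, the seed, and the helper name `split_three_way` are illustrative choices, not part of this template.

```r
# Hypothetical helper: stratified three-way split (not part of the template).
split_three_way <- function(y, props = c(train = 0.6, valid = 0.2, test = 0.2),
                            seed = 1) {
  stopifnot(abs(sum(props) - 1) < 1e-8)
  set.seed(seed)
  assignment <- character(length(y))
  # Split within each class so that class proportions are roughly preserved.
  for (idx in split(seq_along(y), y)) {
    idx <- idx[sample.int(length(idx))]  # shuffle the indices of this class
    n <- length(idx)
    n_train <- round(props[["train"]] * n)
    n_valid <- round(props[["valid"]] * n)
    assignment[idx] <- rep(c("train", "valid", "test"),
                           times = c(n_train, n_valid, n - n_train - n_valid))
  }
  factor(assignment, levels = c("train", "valid", "test"))
}

X <- iris[, setdiff(colnames(iris), "Species")]
y <- iris$Species
split_id <- split_three_way(y)

Xtrain <- X[split_id == "train", ]; ytrain <- y[split_id == "train"]
Xvalid <- X[split_id == "valid", ]; yvalid <- y[split_id == "valid"]
Xtest  <- X[split_id == "test", ];  ytest  <- y[split_id == "test"]

# Compare the three sets: class counts and feature means per split.
class_counts <- table(split = split_id, class = y)
feature_means <- aggregate(X, by = list(split = split_id), FUN = mean)
print(class_counts)
print(feature_means)
```

Stratifying on the response is only one reasonable choice; grouped or time-ordered data would usually call for a different splitting scheme.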
# TODO: pick more interesting datasets
X <- iris %>% dplyr::select(-Species)
y <- iris$Species
data_df <- dplyr::bind_cols(.y = y, X)
splits <- rsample::initial_split(data_df)
train_df <- rsample::training(splits)
valid_df <- rsample::testing(splits)
Xtrain <- X # placeholders for now
Xvalid <- X
Xtest <- X
ytrain <- y
yvalid <- y
ytest <- y
What steps were taken to clean the data? More importantly, why was the data cleaned in this way?
Discuss all inconsistencies, problems, oddities in the data (e.g., missing data, errors in data, outliers, etc.).
Record your preprocessing steps in a way such that if someone else were to reproduce your analysis, they could easily replicate and understand your steps.
It can be helpful to include relevant plots that explain/justify the choices that were made when cleaning the data.
If more than one preprocessing pipeline is reasonable, examine the impacts of these alternative preprocessing pipelines on the final data results.
Again, be as transparent as possible. This allows others to make their own educated decisions on how best to preprocess the data.
Insert narrative here.
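A quick audit of missing values, constant columns, and potential outliers can help document the oddities listed above. The sketch below is a minimal base R example; the helper name `audit_data`, the 1.5 * IQR outlier rule, and the problems injected into the toy data are all assumptions made for illustration.

```r
# Hypothetical audit helper; the name and the 1.5 * IQR rule are assumptions.
audit_data <- function(X) {
  num_cols <- vapply(X, is.numeric, logical(1))
  data.frame(
    variable = colnames(X),
    n_missing = vapply(X, function(col) sum(is.na(col)), integer(1)),
    n_distinct = vapply(X, function(col) length(unique(col[!is.na(col)])),
                        integer(1)),
    n_outliers = vapply(seq_along(X), function(j) {
      if (!num_cols[j]) return(NA_integer_)
      col <- X[[j]]
      q <- stats::quantile(col, c(0.25, 0.75), na.rm = TRUE)
      fence <- 1.5 * (q[2] - q[1])
      # Count values outside the usual boxplot fences.
      sum(col < q[1] - fence | col > q[2] + fence, na.rm = TRUE)
    }, integer(1)),
    row.names = NULL
  )
}

# Toy example: iris with a few artificially injected problems.
X_raw <- iris[, 1:4]
X_raw$Sepal.Length[c(3, 10)] <- NA  # two missing values
X_raw$Petal.Width[5] <- 25          # one implausible measurement
X_raw$constant <- 1                 # an uninformative constant column

audit <- audit_data(X_raw)
print(audit)
```

Flagged values are candidates for inspection, not automatic removal; the decision for each should be recorded alongside its justification.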
The main goal of this section is to give the reader a feel for what the data “looks like” at a basic level.
Provide plots that summarize the data and perhaps even plots that convey some smaller findings which ultimately motivate the main findings.
Provide additional plots representing remaining oddities after pre-processing if applicable.
Add summary statistics in accompanying tables (or in figures) for quick comparisons.
Insert narrative here.
#> Number of samples: 150
#> Number of features: 4
#> Number of NAs in training y: 0
#> Number of NAs in training X: 0
#> Number of columns in training X with NAs: 0
#> Number of constant columns in training X: 0
data_types(Xtrain = Xtrain, ytrain = ytrain)
dt_ls <- data_summary(Xtrain = Xtrain, ytrain = ytrain, digits = 2, sigfig = FALSE)
for (dt_name in names(dt_ls)) {
  subchunkify(dt_ls[[dt_name]], i = chunk_idx, other_args = "results='asis'")
  chunk_idx <- chunk_idx + 1
}
# plot X distribution
plot_X_distribution(Xtrain, "density")
# plot y distribution
plot_y_distribution(ytrain, "bar")
# correlation heatmap
plotCorHeatmap(X = Xtrain, cor.type = "pearson", clust = TRUE, text.size = 8)
# pair plots
col_ids <- 1:min(ncol(Xtrain), 6)
plotPairs(data = Xtrain, columns = col_ids,
          color = ytrain, color.label = "y")
caret::featurePlot(x = Xtrain,
                   y = ytrain,
                   plot = if (is.factor(ytrain)) "box" else "scatter",
                   # strip = strip.custom(par.strip.text = list(cex = .7)),
                   scales = list(x = list(relation = "free"),
                                 y = list(relation = "free")))
# dimension reduction plots
plotPCA(X = Xtrain, npcs = 3, color = ytrain, color.label = "y",
        center = TRUE, scale = FALSE)$plot
For inspiration: Shiny App
Discuss the prediction methods under consideration, and explain why these methods were chosen.
Discuss the accuracy metrics under consideration, and explain why these metrics were chosen.
Note: there should be multiple methods and metrics under consideration to paint a more holistic picture of the data. At least one method should be a baseline, common approach that may not be optimal for the problem setting, but serves as a helpful comparison.
Insert narrative here.
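To make the role of a baseline and of multiple metrics concrete, the sketch below compares accuracy with balanced accuracy for a majority-class baseline. It is a minimal base R illustration: the function names and the imbalanced toy labels are assumptions, not part of this template.

```r
# Baseline classifier: always predict the most frequent training class.
fit_majority <- function(ytrain) {
  names(which.max(table(ytrain)))
}
predict_majority <- function(majority_class, n, levels) {
  factor(rep(majority_class, n), levels = levels)
}

# Two complementary metrics: overall accuracy and balanced accuracy
# (the unweighted mean of per-class recall).
accuracy <- function(obs, pred) {
  mean(as.character(obs) == as.character(pred))
}
balanced_accuracy <- function(obs, pred) {
  recalls <- vapply(levels(obs),
                    function(cl) mean(as.character(pred[obs == cl]) == cl),
                    numeric(1))
  mean(recalls, na.rm = TRUE)
}

# Toy imbalanced example (made-up labels, for illustration only).
ytrain <- factor(c(rep("a", 80), rep("b", 20)))
yvalid <- factor(c(rep("a", 40), rep("b", 10)))
baseline_pred <- predict_majority(fit_majority(ytrain), length(yvalid),
                                  levels(ytrain))
metrics <- c(accuracy = accuracy(yvalid, baseline_pred),
             balanced_accuracy = balanced_accuracy(yvalid, baseline_pred))
print(metrics)
```

On imbalanced data the baseline can look strong on accuracy while doing no better than chance on balanced accuracy, which is one reason to report more than one metric.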
Carry out the prediction pipeline, outlined above.
- Fit prediction methods on training data.
- Evaluate prediction methods on validation data.
- Compare results, and filter out poor models.
Insert narrative here.
# TODO: add code for tuning parameters
# build the recipe from the training data frame rather than the split object
mod_recipe <- recipes::recipe(.y ~ ., data = train_df)
# for classification
rf_model <- parsnip::rand_forest() %>%
  parsnip::set_args(mtry = tune::tune()) %>%
  parsnip::set_engine("ranger", importance = "impurity") %>%
  parsnip::set_mode("classification")
rf_grid <- tidyr::crossing(mtry = 1:4)
svm_model <- parsnip::svm_rbf() %>%
  parsnip::set_engine("kernlab") %>%
  parsnip::set_mode("classification")
knn_model <- parsnip::nearest_neighbor() %>%
  parsnip::set_args(neighbors = tune::tune(), weight_func = tune::tune()) %>%
  parsnip::set_engine("kknn") %>%
  parsnip::set_mode("classification")
# models <- workflowsets::workflow_set(
# preproc = list(Base = mod_recipe),
# models = list(RF = rf_model, SVM = svm_model, KNN = knn_model),
# cross = TRUE
# ) %>%
# workflowsets::option_add(grid = rf_grid, id = "Base_RF")
# model_fits <- workflowsets::workflow_map(
# object = models,
# fn = "tune_grid"
# )
model_list <- list(RF = list(model = rf_model,
                             grid = rf_grid),
                   SVM = list(model = svm_model,
                              grid = NULL),
                   KNN = list(model = knn_model,
                              grid = 4))
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]$model
  grid <- model_list[[model_name]]$grid
  if (!is.null(grid)) {
    mod_fit <- workflows::workflow() %>%
      workflows::add_recipe(mod_recipe) %>%
      workflows::add_model(mod)
    best_params <- mod_fit %>%
      tune::tune_grid(resamples = rsample::vfold_cv(train_df),
                      grid = grid) %>%
      tune::select_best(metric = "accuracy")
    mod_fit <- mod_fit %>%
      tune::finalize_workflow(best_params) %>%
      tune::last_fit(splits)
  } else {
    mod_fit <- workflows::workflow() %>%
      workflows::add_recipe(mod_recipe) %>%
      workflows::add_model(mod) %>%
      tune::last_fit(splits)
  }
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- mod_fit %>%
    tune::collect_predictions()
  model_errs[[model_name]] <- mod_fit %>%
    tune::collect_metrics()
  model_vimps[[model_name]] <- tryCatch({
    # model-specific variable importance
    mod_fit %>%
      workflows::extract_fit_parsnip() %>%
      vip::vi()
  }, error = function(e) {
    # model-agnostic permutation variable importance
    mod_fit %>%
      workflows::extract_fit_parsnip() %>%
      vip::vi(method = "permute", train = train_df, target = ".y",
              feature_names = setdiff(colnames(train_df), ".y"),
              pred_wrapper = predict, metric = "accuracy")
  })
}
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
model_vimps <- dplyr::bind_rows(model_vimps, .id = "model")
# how to do cross validation
trcontrol <- caret::trainControl(
  method = "cv",
  number = 5,
  classProbs = is.factor(ytrain),
  summaryFunction = caret::defaultSummary,
  allowParallel = FALSE,
  verboseIter = FALSE
)
response <- "raw"
model_list <- list(
  ranger = list(tuneGrid = expand.grid(mtry = seq(sqrt(ncol(Xtrain)),
                                                  ncol(Xtrain) / 3,
                                                  length.out = 3),
                                       splitrule = "gini",
                                       min.node.size = 1),
                importance = "impurity",
                num.threads = 1),
  xgbTree = list(tuneGrid = expand.grid(nrounds = c(10, 25, 50, 100, 150),
                                        max_depth = c(3, 6),
                                        colsample_bytree = 0.33,
                                        eta = c(0.1, 0.3),
                                        gamma = 0,
                                        min_child_weight = 1,
                                        subsample = 0.6),
                 nthread = 1)
)
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]
  if (identical(mod, list())) {
    mod <- NULL
  }
  mod_fit <- do.call(caret::train, args = c(list(x = as.data.frame(Xtrain),
                                                 y = ytrain,
                                                 trControl = trcontrol,
                                                 method = model_name),
                                            mod))
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- predict(mod_fit, as.data.frame(Xvalid),
                                       type = response)
  model_errs[[model_name]] <- caret::postResample(
    pred = model_preds[[model_name]], obs = yvalid
  )
  model_vimps[[model_name]] <- caret::varImp(mod_fit)
}
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
model_vimps <- purrr::map_dfr(model_vimps,
                              ~.x[["importance"]] %>%
                                tibble::rownames_to_column("variable"),
                              .id = "model")
library(h2o)
h2o.init(nthreads = 1)
iris.hex <- as.h2o(iris)
splits <- h2o.splitFrame(data = iris.hex,
                         ratios = c(0.8))
train_df <- splits[[1]]
valid_df <- splits[[2]]
model_list <- list(randomForest = list(ntrees = 500),
                   xgboost = list())
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (model_name in names(model_list)) {
  mod <- model_list[[model_name]]
  if (identical(mod, list())) {
    mod <- NULL
  }
  mod_fit <- do.call(paste0("h2o.", model_name),
                     args = c(list(x = colnames(Xtrain),
                                   y = "Species",
                                   training_frame = train_df,
                                   model_id = model_name),
                              mod))
  model_fits[[model_name]] <- mod_fit
  model_preds[[model_name]] <- h2o.predict(mod_fit, valid_df)
  model_errs[[model_name]] <- h2o.performance(mod_fit, valid_df)
  model_vimps[[model_name]] <- h2o.varimp(mod_fit)
}
model_preds <- purrr::map_dfr(model_preds, ~attr(.x, "data"), .id = "model")
model_errs <- purrr::map_dfr(
  model_errs,
  function(err) {
    rm_objs <- c("model", "model_checksum", "frame", "frame_checksum",
                 "description", "scoring_time", "predictions")
    simChef:::simplify_tibble(simChef:::list_to_tibble_row(
      err@metrics[setdiff(names(err@metrics), rm_objs)]
    ))
  },
  .id = "model"
)
model_vimps <- dplyr::bind_rows(model_vimps, .id = "model")
Taking the prediction methods that pass the prediction check, perform stability analysis.
- Specify and justify the appropriate data perturbation(s).
- Re-fit the prediction methods on these perturbed data sets.
- Evaluate prediction methods on validation data.
- Assess stability across the data perturbations as well as across the various methods.
- Filter out poor models where necessary and interpret stability results.
Insert narrative here.
TODO: Ana - Provide some example code here (for both fitting and visualizing results) so that the practitioner can easily input their data and models. Add something like James’ slide. A few pictures with possible data perturbation schemes (separate from parameter tuning). Data splitting vs sampling of observations. Cross-validation-ish scheme and fixed training/validation setup (bootstrapping, sub-sampling, stratified-sampling). Add parameter to include/exclude certain code chunks.
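The steps above can be sketched with a bootstrap perturbation of the training data and a fixed validation set. The example below is a minimal base R illustration: the nearest-centroid classifier is only a simple stand-in for the practitioner's models, and the function names, the seed, the split size, and the number of bootstrap replicates are all assumptions.

```r
# Nearest-centroid classifier, used only as a simple stand-in model.
fit_centroid <- function(X, y) {
  # One row per class, one column per feature.
  sapply(X, function(col) tapply(col, y, mean))
}
predict_centroid <- function(centroids, X) {
  X <- as.matrix(X)
  # Squared Euclidean distance from every observation to every class centroid.
  dists <- sapply(rownames(centroids),
                  function(cl) rowSums(sweep(X, 2, centroids[cl, ])^2))
  rownames(centroids)[apply(dists, 1, which.min)]
}

set.seed(123)
X <- iris[, 1:4]
y <- iris$Species
valid_idx <- sample(nrow(X), 50)
Xtrain <- X[-valid_idx, ]; ytrain <- y[-valid_idx]
Xvalid <- X[valid_idx, ];  yvalid <- y[valid_idx]

n_boot <- 100
boot_acc <- vapply(seq_len(n_boot), function(b) {
  idx <- sample(nrow(Xtrain), replace = TRUE)        # perturb the training data
  fit <- fit_centroid(Xtrain[idx, ], ytrain[idx])    # re-fit on the perturbed data
  mean(predict_centroid(fit, Xvalid) == as.character(yvalid))  # fixed validation set
}, numeric(1))

# Summarize how much validation accuracy varies across perturbations.
stability_summary <- c(mean = mean(boot_acc), sd = sd(boot_acc),
                       quantile(boot_acc, c(0.05, 0.95)))
print(round(stability_summary, 3))
```

Swapping the bootstrap for subsampling or a stratified scheme only changes how `idx` is drawn; the choice should reflect how the data could plausibly have differed.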
For the models that pass the prediction and stability checks,
- Extract the important features in the predictive models that are stable across both data and model perturbations. Determining the importance of a feature can be method dependent.
Insert narrative here.
TODO: Tiffany - Provide some example code here (for both fitting and visualizing results) so that the practitioner can easily input their data and models.
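One way to check whether important features are stable is to recompute an importance score under each data perturbation and compare the resulting rankings. The sketch below uses permutation importance with a nearest-centroid stand-in model in base R; the model, the function names, the seed, and the "top 2" threshold are assumptions chosen for illustration.

```r
# Simple stand-in model (nearest centroid); not part of the template.
fit_centroid <- function(X, y) sapply(X, function(col) tapply(col, y, mean))
predict_centroid <- function(centroids, X) {
  X <- as.matrix(X)
  dists <- sapply(rownames(centroids),
                  function(cl) rowSums(sweep(X, 2, centroids[cl, ])^2))
  rownames(centroids)[apply(dists, 1, which.min)]
}

# Permutation importance: drop in accuracy after shuffling one column.
perm_importance <- function(fit, X, y) {
  y <- as.character(y)
  base_acc <- mean(predict_centroid(fit, X) == y)
  vapply(colnames(X), function(v) {
    X_perm <- X
    X_perm[[v]] <- sample(X_perm[[v]])
    base_acc - mean(predict_centroid(fit, X_perm) == y)
  }, numeric(1))
}

set.seed(42)
X <- iris[, 1:4]
y <- iris$Species
valid_idx <- sample(nrow(X), 50)
Xtrain <- X[-valid_idx, ]; ytrain <- y[-valid_idx]
Xvalid <- X[valid_idx, ];  yvalid <- y[valid_idx]

# Importance of every feature under each bootstrap perturbation.
n_boot <- 50
imp <- t(vapply(seq_len(n_boot), function(b) {
  idx <- sample(nrow(Xtrain), replace = TRUE)
  perm_importance(fit_centroid(Xtrain[idx, ], ytrain[idx]), Xvalid, yvalid)
}, numeric(ncol(X))))

# Rank features within each perturbation (1 = most important), then summarize.
ranks <- t(apply(-imp, 1, rank))
stability <- data.frame(variable = colnames(imp),
                        mean_importance = colMeans(imp),
                        sd_importance = apply(imp, 2, sd),
                        prop_top2 = colMeans(ranks <= 2),
                        row.names = NULL)
print(stability[order(-stability$mean_importance), ])
```

Features that rank highly in most perturbations, and across several models, are better candidates for interpretation than features that are important in only a few fits.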
prettyDT(model_vimps, digits = 2, sigfig = FALSE, caption = "Variable Importances")
# bar plot
vip::vip(model_vimps,
         num_features = 10,
         geom = "col") +
  prettyGGplotTheme()
# scatter plot
plt <- model_vimps %>%
  tidyr::pivot_wider(names_from = "model", values_from = "Importance") %>%
  plotPairs(columns = which(!(colnames(.) %in% "Variable"))) +
  ggplot2::theme_bw()
plotly::ggplotly(plt)
TODO
Move beyond the global prediction accuracy metrics and dive deeper into individual-level predictions for the validation and/or test set, i.e., provide a more “local” analysis.
- Examine any points that had poor predictions.
- Examine differences between prediction methods.
Insert narrative here.
TODO: Tiffany - Add examples with interesting observations of prediction accuracy metrics so the user knows what to look for.
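A simple local analysis flags observations that every model gets wrong and observations on which the models disagree. The sketch below uses base R with made-up toy predictions from two hypothetical models; the column names (`RF`, `SVM`, `obs`) are assumptions and would be replaced by the practitioner's own prediction table.

```r
# Toy predictions from two hypothetical models (made-up values).
preds <- data.frame(
  .row = 1:8,
  obs = c("a", "a", "b", "b", "a", "b", "a", "b"),
  RF = c("a", "a", "b", "a", "a", "b", "b", "b"),
  SVM = c("a", "b", "b", "a", "a", "b", "a", "b"),
  stringsAsFactors = FALSE
)
model_cols <- c("RF", "SVM")

# Logical matrix: is each model's prediction correct for each observation?
correct <- sapply(preds[model_cols], function(p) p == preds$obs)
preds$n_models_wrong <- rowSums(!correct)
preds$models_disagree <- preds$RF != preds$SVM

# Observations that every model misclassifies.
hard_cases <- preds[preds$n_models_wrong == length(model_cols), ]
# Observations where the two models give different predictions.
disagreements <- preds[preds$models_disagree, ]
print(hard_cases)
print(disagreements)
```

Points missed by all models may indicate label noise or unusual feature values, while disagreements highlight where the methods capture different structure.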
model_preds %>%
tidyr::pivot_wider(names_from = "model", values_from = ".pred_setosa",
id_cols = c("id", ".row")) %>%
  plotPairs(columns = which(!(colnames(.) %in% c("id", ".row"))))
Reiterate main findings, note any caveats, and clearly translate findings/analysis back to the domain problem context.
Insert narrative here.